import json
import re
from typing import List, Dict, Any
from tqdm import tqdm
from latex2sympy2 import latex2sympy
from sympy import simplify

def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its records in file order.

    Args:
        file_path: Path to a UTF-8 encoded file with one JSON value per line.

    Returns:
        The decoded value of every non-blank line.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (typically a trailing newline at end of file) carry no
        # record; feeding them to json.loads would raise.
        return [json.loads(line) for line in f if line.strip()]

def save_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Write each record of ``data`` as one JSON line to ``file_path``.

    The file is opened in write mode with UTF-8 encoding, so any existing
    content is replaced. Records are serialized with the ``json`` module's
    default settings.
    """
    with open(file_path, 'w', encoding='utf-8') as sink:
        sink.writelines(json.dumps(record) + '\n' for record in data)

def extract_final_answer(response: str) -> str:
    """Extract the final answer from a model response.

    The content of the last top-level ``\\boxed{...}`` group is returned,
    stripped of surrounding whitespace. The group is located by counting
    braces rather than by a greedy regex, so nested groups such as
    ``\\boxed{\\frac{1}{2}}``, several boxes on one line, text containing
    braces after the box, and boxes spanning multiple lines are all handled.

    Args:
        response: Raw text produced by the model.

    Returns:
        The content of the last balanced ``\\boxed{...}``; if there is none,
        the last non-blank line of the response (stripped); if the response
        has no non-blank line, the stripped response itself.
    """
    marker = "\\boxed{"
    length = len(response)
    last_boxed = None
    search_from = 0

    while True:
        start = response.find(marker, search_from)
        if start == -1:
            break

        content_start = start + len(marker)
        depth = 1
        idx = content_start
        while idx < length and depth > 0:
            ch = response[idx]
            if ch == '\\':
                # Skip the escaped character so that \{ and \} are not
                # counted as grouping braces.
                idx += 2
                continue
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
            idx += 1

        if depth == 0:
            # idx is one past the closing brace of this group.
            last_boxed = response[content_start:idx - 1]
            search_from = idx
        else:
            # Unbalanced group: ignore it, but keep looking for a complete
            # \boxed{...} further along.
            search_from = content_start

    if last_boxed is not None:
        return last_boxed.strip()

    # Fail case: no boxed answer, fall back to the last non-blank line.
    lines = [line.strip() for line in response.split('\n') if line.strip()]
    return lines[-1] if lines else response.strip()

def is_correct(model_answer: str, ground_truth: str) -> bool:
    """Decide whether a model answer matches the ground truth.

    Answers whose stripped text is identical are accepted immediately.
    Otherwise both strings are parsed as LaTeX with ``latex2sympy`` and are
    considered equal when their simplified difference is exactly 0.

    Args:
        model_answer: Answer extracted from the model response.
        ground_truth: Reference answer.

    Returns:
        True if the answers match textually or symbolically, else False.
        Any failure during parsing or symbolic comparison yields False,
        because the textual comparison has already failed by that point.
    """
    # Identical text is correct by definition. Checking it first also avoids
    # false negatives for values whose difference is not 0 symbolically
    # (e.g. infinity minus infinity is NaN) and skips needless parsing.
    if model_answer.strip() == ground_truth.strip():
        return True

    try:
        model_simplified = simplify(latex2sympy(model_answer))
        truth_simplified = simplify(latex2sympy(ground_truth))
        simplified_diff = simplify(model_simplified - truth_simplified)
        return bool(simplified_diff == 0)
    except Exception:
        # Best effort: the parser and simplifier can fail in many ways on
        # free-form model output (unparseable LaTeX, non-subtractable
        # objects, ...). Treat any such failure as "not equal".
        return False


def evaluate_responses(data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Score every item's model responses and aggregate accuracy metrics.

    Each item is read through the keys ``"answer"`` (ground truth) and
    ``"model_responses"`` (list of response texts). Items that are scored get
    a new ``"eval_results"`` key holding one 0/1 flag per response; ``data``
    is therefore modified in place. Items with no responses or an empty
    ground truth are skipped and receive no ``"eval_results"`` key.

    Args:
        data: Items to evaluate.

    Returns:
        A dict with:
            ``"total"``: ``len(data)``, including skipped items.
            ``"best_of_n"``: fraction of items with at least one correct
                response.
            ``"average"``: mean over items of the per-item fraction of
                correct responses.
            ``"majority_vote"``: fraction of items where strictly more than
                half of the responses are correct.
        All three fractions use ``"total"`` as the denominator, so skipped
        items contribute zero to each of them.
    """
    results = {
        "total": len(data),
        "best_of_n": 0,
        "average": 0,
        "majority_vote": 0
    }

    for item in tqdm(data, desc="Evaluating responses"):
        # The ground truth may be stored as a JSON number or null; coerce it
        # to text so that .strip() and the later comparison cannot fail.
        raw_truth = item.get("answer")
        ground_truth = "" if raw_truth is None else str(raw_truth).strip()
        responses = item.get("model_responses", [])

        if not responses or not ground_truth:
            continue

        binary_eval_results = [
            1 if is_correct(extract_final_answer(response), ground_truth) else 0
            for response in responses
        ]
        num_correct = sum(binary_eval_results)
        num_responses = len(binary_eval_results)

        item["eval_results"] = binary_eval_results
        results["best_of_n"] += int(num_correct > 0)
        results["average"] += num_correct / num_responses
        results["majority_vote"] += int(num_correct > num_responses / 2)

    # Turn the accumulated sums into averages; with no data they stay at 0.
    if results["total"] > 0:
        for key in ["best_of_n", "average", "majority_vote"]:
            results[key] /= results["total"]

    return results

if __name__ == '__main__':
    # Script entry point: load the responses, score them, write the
    # annotated records back out, and print a summary of the metrics.
    source_path, target_path = "math_Qwen.jsonl", "eval_math_Qwen.jsonl"

    print(f"Loading data from {source_path}...")
    records = load_jsonl(source_path)

    print("Evaluating model responses...")
    metrics = evaluate_responses(records)
    # The same record list that was passed to evaluate_responses is saved.
    save_jsonl(records, target_path)

    report_lines = [
        "\nEvaluation Results:",
        f"Total samples: {metrics['total']}",
        f"Best of N accuracy: {metrics['best_of_n']:.4f}",
        f"Average accuracy: {metrics['average']:.4f}",
        f"Majority vote accuracy: {metrics['majority_vote']:.4f}",
    ]
    for report_line in report_lines:
        print(report_line)
